jax>=0.3.4
numpy>=1.21
flax>=0.4
matplotlib>=3.5
tensorflow>=2.8
tensorflow-datasets>=4.5
optax>=0.1.1
tree-math>=0.1
